package main

import (
	"bytes"
	"flag"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"os/exec"
	"path"
	"path/filepath"
	"regexp"
	"strings"

	"gorgonia.org/cu"
)

var (
	// debug makes compileCUDA pass -lineinfo to nvcc.
	debug = flag.Bool("debug", false, "compile with debug mode (-lineinfo is added to nvcc call)")
	// sameModule makes main generate AddToStdLib calls without the gorgonia
	// import and package qualifier, for a file living inside gorgonia itself.
	sameModule = flag.Bool("same-module", false, "generate a cudamodules.go file which can be placed in the gorgonia dir instead of the application main dir")
)

// funcNameRegex captures the symbol name that follows "// .globl" on a PTX
// comment line (terminated by "\n" or "\r\n").
var funcNameRegex = regexp.MustCompile(`// \.globl[ \t]+(.+?)\r?\n`)

// stripExt returns the last element of fullpath with its extension (the
// suffix starting at the final '.') removed, e.g. "dir/kernels.cu" -> "kernels".
// A path ending in a separator yields "".
func stripExt(fullpath string) string {
	_, base := filepath.Split(fullpath)
	return strings.TrimSuffix(base, path.Ext(base))
}

// compileCUDA compiles the CUDA source file src to PTX for the virtual
// architecture compute_<maj><min> by invoking nvcc, and returns the PTX text.
//
// nvcc writes its output to a temporary file that is always removed before
// returning, on success and on failure. Any output on nvcc's stderr is treated
// as a failure, even if nvcc exits with status 0. When the -debug flag is set,
// -lineinfo is also passed to nvcc.
func compileCUDA(src string, maj, min int) ([]byte, error) {
	target, err := ioutil.TempFile("", stripExt(src)+"_*.ptx")
	if err != nil {
		return nil, fmt.Errorf("failed to create temporary file for compilation output: %v", err)
	}
	targetName := target.Name()
	// Only the reserved name is needed; nvcc writes the file itself. Closing
	// the handle now (nothing was written through it) also lets os.Remove
	// succeed on platforms that refuse to delete open files.
	_ = target.Close()
	defer func() {
		if err := os.Remove(targetName); err != nil {
			log.Printf("could not remove temporary file %v", targetName)
		}
	}()

	args := []string{
		"-o=" + targetName,
		fmt.Sprintf("-arch=compute_%d%d", maj, min),
	}
	if *debug {
		args = append(args, "-lineinfo")
	}
	args = append(args, "-ptx", "-Xptxas", "--allow-expensive-optimizations",
		"-fmad=false", "-ftz=false", "-prec-div=true", "-prec-sqrt=true", src)

	cmd := exec.Command("nvcc", args...)
	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil || stderr.Len() != 0 {
		return nil, fmt.Errorf("failed to compile with nvcc. Error: %v. nvcc error: %v", err, stderr.String())
	}

	// Read by name rather than through the TempFile handle, so the result is
	// correct even if nvcc replaced the file instead of writing into it.
	out, err := ioutil.ReadFile(targetName)
	if err != nil {
		return nil, fmt.Errorf("failed to read compilation output file. Error: %v", err)
	}
	return out, nil
}

// packageLoc returns the source directory of the Go package with import path
// name, as printed by `go list -f {{.Dir}} -find`. Anything written to the go
// tool's stderr is treated as a failure.
func packageLoc(name string) (string, error) {
	var out, errOut bytes.Buffer
	cmd := exec.Command("go", "list", "-f", "{{.Dir}}", "-find", name)
	cmd.Stdout, cmd.Stderr = &out, &errOut

	runErr := cmd.Run()
	if runErr == nil && errOut.Len() == 0 {
		return strings.TrimSpace(out.String()), nil
	}
	return "", fmt.Errorf("failed to locate %v. Error: %v. go list error: %v", name, runErr, errOut.String())
}

// packageInWorkingDir returns the name of the Go package in the current
// working directory, as printed by `go list -f {{.Name}}`. Anything written to
// the go tool's stderr is treated as a failure.
func packageInWorkingDir() (string, error) {
	var out, errOut bytes.Buffer
	cmd := exec.Command("go", "list", "-f", "{{.Name}}")
	cmd.Stdout, cmd.Stderr = &out, &errOut

	if runErr := cmd.Run(); runErr != nil || errOut.Len() > 0 {
		return "", fmt.Errorf("failed to get name of package in working directory. Error: %v. go list error: %v", runErr, errOut.String())
	}
	return strings.TrimSpace(out.String()), nil
}

// gofmt rewrites the Go source file at path in place by running `gofmt -w`.
// On failure the returned error includes gofmt's stderr output.
func gofmt(path string) error {
	cmd := exec.Command("gofmt", "-w", path)
	var stderr bytes.Buffer
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		// The tool invoked is gofmt, not goimports; say so in the message.
		return fmt.Errorf("gofmt failed with %v for %q. Error: %v", err, path, stderr.String())
	}
	return nil
}

func main() {
	flag.Parse()

	var devices int
	var err error
	if devices, err = cu.NumDevices(); err != nil {
		log.Fatalf("error while finding number of devices: %+v", err)
	}
	if devices == 0 {
		log.Fatal("No CUDA-capable devices found")
	}

	// Get the lowest possible compute capability
	major := int(^uint(0) >> 1)
	minor := int(^uint(0) >> 1)
	for d := 0; d < devices; d++ {
		var dev cu.Device
		if dev, err = cu.GetDevice(d); err != nil {
			log.Fatalf("Unable to get GPU%d - %+v", d, err)
		}

		maj, min, err := dev.ComputeCapability()
		if err != nil {
			log.Fatalf("Unable to get compute compatibility of GPU%d - %v", d, err)
		}
		if maj > 0 && maj < major {
			major = maj
			minor = min
			continue
		}

		if min > 0 && min < minor {
			minor = min
		}
	}

	cwd, err := os.Getwd()
	if err != nil {
		log.Fatal(err)
	}
	cudamodules := path.Join(cwd, "cudamodules.go")
	packageName, err := packageInWorkingDir()
	if err != nil {
		log.Fatal(err)
	}

	gorgoniaLoc, err := packageLoc("gorgonia.org/gorgonia")
	if err != nil {
		log.Fatal(err)
	}
	cuLoc := path.Join(gorgoniaLoc, "cuda modules", "src", "*.cu")

	matches, err := filepath.Glob(cuLoc)
	if err != nil {
		log.Fatal(err)
	}

	m := make(map[string][]byte)
	funcs := make(map[string][]string)
	for _, match := range matches {
		name := stripExt(match)
		data, err := compileCUDA(match, major, minor)
		if err != nil {
			log.Fatal(err)
		}
		m[name] = data

		// Regex
		var fns []string
		matches := funcNameRegex.FindAllSubmatch(data, -1)
		for _, bs := range matches {
			fns = append(fns, string(bs[1]))
		}
		funcs[name] = fns
	}

	var buf bytes.Buffer
	header := fmt.Sprintf(`// Code generated by Gorgonia cudagen. DO NOT EDIT.
// +build cuda

package %v
`, packageName)
	buf.WriteString(header)
	if ! *sameModule {
		buf.WriteString("import \"gorgonia.org/gorgonia\"\n")
	}

	buf.WriteString("func init() {\n")
	for name := range m {
		if ! *sameModule {
			buf.WriteString("gorgonia.")
		}
		buf.WriteString(fmt.Sprintf("AddToStdLib(%q, %sPTX, []string{\"%s\"})\n", name, name, strings.Join(funcs[name], "\", \"")))
	}
	buf.WriteString("}\n")

	for name, data := range m {
		buf.WriteString(fmt.Sprintf("const %vPTX = `", name))
		buf.Write(data)
		buf.WriteString("`\n")
	}

	f, err := os.OpenFile(cudamodules, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	if _, err = buf.WriteTo(f); err != nil {
		log.Fatalf("unable to write output to %v", cudamodules)
	}

	if err = gofmt(cudamodules); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Created %v\n", cudamodules)
}
